MNIST classification using Flux.jl

md"# MNIST classification using Flux.jl"
340 ms
begin
    using Flux
    using Plots
    using Random
    using Metrics
    using MLUtils
    using Statistics
    using MLDatasets
    using ProgressLogging

    Random.seed!(9)
end
188 s

Loading MNIST dataset

md"### Loading MNIST dataset"
179 μs
begin
    train_x, train_y = MLDatasets.MNIST(:train)[:]
    size(train_x), size(train_y)
end
9.7 s
begin
    test_x, test_y = MLDatasets.MNIST(:test)[:]
    size(test_x), size(test_y)
end
540 ms
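
As a quick sanity check on the loaded arrays, a throwaway cell like the sketch below (not part of the original notebook) can render one training image together with its label; the permutedims and yflip only serve to orient the digit upright in Plots' default coordinate system.

# Sketch: preview the first training image and its label
heatmap(permutedims(train_x[:, :, 1]); yflip=true,
        title="First training image, label = $(train_y[1])")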

Preparing data

md"### Preparing data"
176 μs
ohb_train_y, ohb_test_y = Flux.onehotbatch(train_y, 0:9), Flux.onehotbatch(test_y, 0:9)
308 ms
train_data_loader
60-element DataLoader(::Tuple{Array{Float32, 3}, OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}, batchsize=1000)
  with first element:
  (28×28×1000 Array{Float32, 3}, 10×1000 OneHotMatrix(::Vector{UInt32}) with eltype Bool,)
train_data_loader = DataLoader((train_x, ohb_train_y), batchsize=1000)
59.1 μs
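
To confirm the encoding and batching behave as expected, a sketch like the following (not a cell from the original notebook) checks that onecold inverts onehotbatch and that the first batch is a 28×28×1000 image block paired with a 10×1000 label matrix.

begin
    # Sketch: round-trip the first few labels and inspect one batch
    Flux.onecold(ohb_train_y[:, 1:5], 0:9) == train_y[1:5]   # expected: true
    xb, yb = first(train_data_loader)
    (size(xb), size(yb))   # expected: ((28, 28, 1000), (10, 1000))
end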

Defining model

md"### Defining model"
183 μs
begin
    model = Chain(
        Flux.flatten,
        # 28×28 pixels flattened to 784 input features
        Dense((train_x |> size)[1:2] |> prod => 64, relu),
        Dense(64 => 32, relu),
        Dense(32 => 10),
        softmax
    ) |> gpu
    optimizer = Adam()
end
The GPU function is being called but the GPU is not accessible. 
Defaulting back to the CPU. (No action is required if you want to run on the CPU).
7.3 s
Chain(
  Flux.flatten,
  Dense(784 => 64, relu),               # 50_240 parameters
  Dense(64 => 32, relu),                # 2_080 parameters
  Dense(32 => 10),                      # 330 parameters
  NNlib.softmax,
)                   # Total: 6 arrays, 52_650 parameters, 206.039 KiB.
20.6 μs
loss (generic function with 1 method)
loss(x, y) = Flux.crossentropy(model(x), y) |> gpu
1.3 ms
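
Before training it can help to verify the forward pass with a small batch; the sketch below (not in the original notebook) checks that the chain maps images to 10-way probability columns that each sum to roughly 1 thanks to the final softmax.

let xb = train_x[:, :, 1:5]
    yhat = model(xb)
    # expected: size (10, 5), with each column summing to ≈ 1
    (size(yhat), sum(yhat; dims=1))
end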

Training phase

md"### Training phase"
356 μs
epochs
20
epochs = 20
25.9 μs
begin
    train_losses = Float32[]
    test_losses = Float32[]

    @progress for e in 1:epochs
        # One pass over the training batches
        for d in train_data_loader
            gs = gradient(Flux.params(model)) do
                loss(d...)
            end
            Flux.update!(optimizer, Flux.params(model), gs)
        end
        # Record the loss on the full train and test sets after each epoch
        push!(train_losses, loss(train_x, ohb_train_y))
        push!(test_losses, loss(test_x, ohb_test_y))
    end
end
100% 113 s
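
For reference, the per-batch part of the loop above matches what Flux's implicit-parameter convenience wrapper does; a minimal sketch of one epoch written that way (not a cell from the original notebook):

# Sketch: one epoch of the same batch loop via the implicit-params wrapper
Flux.train!(loss, Flux.params(model), train_data_loader, optimizer)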

Testing phase

md"### Testing phase"
220 μs
pred
10×10000 Matrix{Float32}:
 9.66887f-5   2.92489f-7   3.26758f-5   …  4.43695f-10  8.37695f-9   2.03662f-6
 6.7984f-7    5.18221f-5   0.972795        2.70698f-11  2.28261f-9   1.34618f-13
 0.000140351  0.999722     0.00476795      9.11259f-10  3.35669f-8   3.36648f-6
 0.000625461  0.000215399  0.00145327      3.37758f-9   8.10972f-9   6.33107f-11
 1.07602f-7   2.5243f-12   0.00272533      0.999866     2.65759f-8   1.29285f-5
 8.25873f-6   6.90715f-6   0.00280023   …  7.27103f-7   0.999874     1.3763f-7
 7.75011f-11  6.53439f-7   0.000996045     1.44187f-7   3.35018f-8   0.999981
 0.999013     2.01761f-8   0.00299345      2.38273f-5   9.20886f-10  1.57058f-11
 1.1349f-5    3.14942f-6   0.0105128       1.40182f-5   0.000124549  2.17979f-7
 0.000104043  2.43567f-13  0.000922991     9.56573f-5   9.57752f-7   3.24698f-8
pred = model(test_x)
34.1 ms
true_pred
true_pred = Flux.onecold(pred, 0:9)
406 ms
ohb_true_pred
10×10000 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  1  …  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  1  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  1  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  …  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅     ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1
 1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  1  ⋅  ⋅  1  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
ohb_true_pred = Flux.onehotbatch(true_pred, 0:9)
117 μs
0.9687
Metrics.categorical_accuracy(ohb_true_pred, ohb_test_y)
775 ms
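
The same number can be cross-checked without Metrics by comparing the decoded predictions against the raw test labels; a minimal sketch using Statistics.mean (not part of the original notebook), which should agree with the 0.9687 reported above.

# Sketch: accuracy as the fraction of matching labels
mean(true_pred .== test_y)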
begin
    plot(1:epochs, train_losses, title="Loss over epochs: $epochs", label="Train")
    plot!(1:epochs, test_losses, label="Test")
end
10.9 s